import numpy as np
from scipy.stats import ortho_group


def weighted_dot(u, v=None, A=None):
    """Return the scalar u^T A v, i.e. the inner product of u and v weighted by the matrix A.

    If v is omitted, u is paired with itself. If A is omitted, the identity matrix of size len(v) is used,
    which reduces the result to the ordinary dot product.
    """
    right = u if v is None else v
    weight = np.eye(len(right)) if A is None else A
    # Treat u as a row vector and v as a column vector so that the product is a 1x1 matrix.
    as_row = u.reshape(1, -1)
    as_col = right.reshape(-1, 1)
    product = (as_row @ weight) @ as_col
    return product.flatten()[0]


def discretize_cmap(cmap, num_colors, invert_cmap=False, margin=1):
    """Extract num_colors evenly spaced colors from a colormap.

    Parameters
    ----------
    cmap : callable with an integer attribute ``N``
        Colormap; ``cmap(i)`` is evaluated for integer indices ``i`` in ``range(cmap.N)``.
    num_colors : int
        Number of colors to pick.
    invert_cmap : bool
        If True, the trimmed color list is reversed before picking.
    margin : int
        Number of entries dropped at each end of the colormap before picking. ``margin=0`` keeps the
        full colormap.

    Returns
    -------
    list
        The selected colors, in the form returned by ``cmap(i)``.

    Raises
    ------
    ValueError
        If margin is negative.
    IndexError
        If trimming leaves no colors while num_colors > 0.
    """
    if margin < 0:
        raise ValueError(f"margin must be non-negative, got {margin}")
    # Build the trimmed index range explicitly: the slice [margin:-margin] collapses to the empty
    # slice [0:0] when margin == 0.
    cmap_list = [cmap(i) for i in range(margin, cmap.N - margin)]
    if invert_cmap:
        cmap_list.reverse()
    # Evenly spaced positions over the trimmed list, truncated to integer indices.
    cmap_idxs = [int(pos) for pos in np.linspace(0, len(cmap_list) - 1, num_colors)]
    return [cmap_list[idx] for idx in cmap_idxs]


def draw_spherical(dim, rad):
    """Draw one random point uniformly from the sphere of radius rad in dimension dim.

    A standard normal vector is sampled with NumPy's global random state and rescaled to length rad.
    """
    direction = np.random.standard_normal(dim)
    # Normalizing an isotropic Gaussian sample gives a uniformly distributed direction.
    scale = rad / np.linalg.norm(direction)
    return direction * scale


def get_max_eta(conf_eta_list, r_sq_stat):
    """Return the largest eta compatible with every given generalized confounding strength.

    Parameters
    ----------
    conf_eta_list : scalar or sequence of scalars
        One generalized confounding strength, or several (list, tuple or array).
    r_sq_stat :
        Statistical signal.

    Returns
    -------
    The minimum over all conf_eta of the upper bound (conf_eta - conf_eta ** 2) * r_sq_stat.
    """
    # Wrap only genuine scalars; an exact `type(...) != list` test would also wrap tuples and list
    # subclasses and then fail on the arithmetic below.
    if np.ndim(conf_eta_list) == 0:
        conf_eta_list = [conf_eta_list]
    eta_upper_bound_list = [(conf_eta - conf_eta ** 2) * r_sq_stat for conf_eta in conf_eta_list]
    return np.min(np.array(eta_upper_bound_list))


def get_params_from_generalized_conf(conf_eta, r_sq_stat, eta=None):
    """Convert a generalized confounding strength into the pair (conf, eta).

    Given the generalized confounding strength conf_eta and the signal r_sq_stat, returns the old confounding
    strength conf = omega_sq / (r_sq + omega_sq) together with the inner product eta. If eta is not passed,
    the squared length of the orthogonal component is drawn from an exponential distribution with scale 5
    and eta is derived from it; otherwise eta is validated against its upper bound and returned unchanged.
    """
    draw_eta = eta is None
    if draw_eta:
        ortho_len_sq = np.random.exponential(5)
    else:
        # Largest eta that still leaves a non-negative squared orthogonal length.
        eta_cap = (conf_eta - conf_eta ** 2) * r_sq_stat
        assert eta <= eta_cap, rf"$\eta$ cannot be larger than {eta_cap}!"
        ortho_len_sq = eta_cap - eta

    omega_sq = conf_eta ** 2 * r_sq_stat + ortho_len_sq
    r_sq = (1 - conf_eta) ** 2 * r_sq_stat + ortho_len_sq

    if draw_eta:
        # Follows from r_sq_stat = r_sq + omega_sq + 2 * eta.
        eta = (r_sq_stat - r_sq - omega_sq) / 2
    return omega_sq / (r_sq + omega_sq), eta


def generate_causal_params_for_fixed_statistical(d, l, r_sq_stat, sigma_sq_stat, sigma_sq, conf_strength, eta=None, theta=None, test_result=True, M_params=None):
    """Generates the hyperparameters for a causal model, which satisfies the given constraints of having a fixed corresponding statistical model
    (r_sq_stat, sigma_sq_stat), fixed dimensions (d, l), fixed confounding strength (conf_strength), and fixed angle between causal signal beta and
    confounding vector Gamma (either specified via inner product eta or angle theta).

    With Gamma = pinv(M)^T alpha, the constraints (as checked in step 6) are
        ||beta + Gamma||^2                        = r_sq_stat
        sigma_sq + ||alpha||^2 - ||Gamma||^2      = sigma_sq_stat
        ||Gamma||^2 / (||beta||^2 + ||Gamma||^2)  = conf_strength
        <beta, Gamma> = eta   or   angle(beta, Gamma) = theta

    Parameters
    ----------
    d, l : int
        Dimensions of the d x l matrix M; beta has length d and alpha has length l. Requires l >= d.
    r_sq_stat, sigma_sq_stat :
        Signal and noise of the statistical model.
    sigma_sq :
        Noise of the causal model; must not exceed sigma_sq_stat.
    conf_strength :
        Confounding strength omega_sq / (r_sq + omega_sq).
    eta, theta :
        Inner product or angle (in [0, pi]) between beta and Gamma. Exactly one must be given.
    test_result : bool
        If True, the generated parameters are checked against the constraints above.
    M_params : tuple (U, V, M), optional
        Reuse a previously generated decomposition instead of drawing a new random M.

    Returns
    -------
    tuple
        (M, alpha, beta, sigma_sq)

    Raises
    ------
    AssertionError
        If the inputs are inconsistent or, with test_result=True, the generated model violates a constraint.
    """
    assert sigma_sq_stat >= sigma_sq, "Statistical noise sigma_sq_stat cannot be smaller than causal noise sigma_sq!"
    assert (eta is None) ^ (theta is None), "Exactly on of eta (inner product) or theta (angle) must be specified!"
    if conf_strength == .5 and theta == np.pi:
        assert r_sq_stat == 0, "Confounding strength .5 and angle pi require statistical signal constant 0."
    if theta is not None:
        assert (0 <= theta) and (theta <= np.pi), "Angle theta must be between 0 and pi!"

    # 1. Choose M randomly such that MM^T=I
    if M_params is None:
        U, V = ortho_group.rvs(d), ortho_group.rvs(l)
        Lambda = np.concatenate((np.eye(d), np.zeros((d, l - d))), axis=1)
        M = (U @ Lambda) @ V.T
    else:
        U, V, M = M_params

    if conf_strength == 0:
        # No confounding: Gamma vanishes and beta carries the whole statistical signal.
        # NOTE(review): alpha = 0 makes the implied statistical noise equal to sigma_sq, so sigma_sq_stat is
        # only met when sigma_sq_stat == sigma_sq, and none of the checks in step 6 run on this path - confirm
        # that this is intended.
        alpha = np.zeros(l)
        beta = draw_spherical(d, np.sqrt(r_sq_stat))
    else:
        # 2. Use the angle (eta/theta) together with equations for r_sq_stat and conf_strength to define r_sq and omega_sq
        if eta is not None:
            # r_sq_stat = r_sq + omega_sq + 2 * eta and conf_strength = omega_sq / (r_sq + omega_sq),
            # hence r_sq + omega_sq = r_sq_stat - 2 * eta and omega_sq is the conf_strength share of that sum.
            omega_sq = conf_strength * (r_sq_stat - 2 * eta)
            r_sq = r_sq_stat - omega_sq - 2 * eta
            # From now on, continue with the angle; clipping guards arccos against rounding outside [-1, 1]
            theta = np.arccos(np.clip(eta / np.sqrt(r_sq * omega_sq), -1, 1))
        elif theta is not None:
            if theta == np.pi / 2:
                # Orthogonal case: the cross term vanishes, so r_sq_stat = r_sq + omega_sq
                omega_sq = conf_strength * r_sq_stat
                r_sq = r_sq_stat - omega_sq
            else:
                # Substitute r_sq = (1 - conf_strength) / conf_strength * omega_sq into
                # r_sq_stat = r_sq + omega_sq + 2 * cos(theta) * sqrt(r_sq * omega_sq) and solve for omega_sq
                omega_sq = r_sq_stat / (1 / conf_strength + 2 * np.cos(theta) * np.sqrt((1 - conf_strength) / conf_strength))
                r_sq = (1 - conf_strength) / conf_strength * omega_sq
        # 3. Use equation for sigma_sq_stat to define s_sq
        s_sq = sigma_sq_stat - sigma_sq

        # 4. Choose appropriate random alpha to satisfy ||alpha||^2 = omega_sq + s_sq and ||Gamma||^2 = omega_sq, where Gamma=M@alpha
        # Draw components of alpha represented in an orthonormal basis of M^TM from the uniform sphere and scale appropriately
        alpha_M = draw_spherical(dim=d, rad=np.sqrt(omega_sq))
        alpha_MC = draw_spherical(dim=l - d, rad=np.sqrt(s_sq))
        alpha = V @ np.concatenate([alpha_M, alpha_MC])
        # For M with MM^T=I, pinv(M).T coincides with M
        Gamma = np.linalg.pinv(M).T @ alpha
        assert np.isclose(np.linalg.norm(Gamma)**2, omega_sq), f"||Gamma||^2={np.linalg.norm(Gamma)**2}!=omega^2={omega_sq}"
        assert np.isclose(np.linalg.norm(alpha)**2 - Gamma.T @ (M @ M.T) @ Gamma, s_sq), "Gamma noise is wrong"

        # 5. Choose appropriate random beta to satisfy ||beta||^2 = r_sq and angle theta between beta and Gamma
        # Generate random vector w with same norm as Gamma which is orthogonal to Gamma
        if conf_strength == 1:
            # Pure confounding: no causal signal
            beta = np.zeros(d)
        else:
            w = draw_spherical(d, 1)
            w = w - np.inner(w, Gamma) / omega_sq * Gamma  # Gram-Schmidt
            w *= np.sqrt(omega_sq) / np.linalg.norm(w)  # Scale to same length
            # Define beta as a linear combination of Gamma and w, then rescale
            beta = np.cos(theta) * Gamma + np.sin(theta) * w
            beta *= np.sqrt(r_sq) / np.linalg.norm(beta)

        # 6. Check if the causal model satisfies all constraints
        if test_result:
            assert np.isclose(r_sq_stat, np.linalg.norm(beta + Gamma)**2, atol=1e-7), f"Wrong statistical signal r_sq_stat={r_sq_stat}!={np.linalg.norm(beta + Gamma)**2}"
            assert np.isclose(sigma_sq_stat, sigma_sq + np.linalg.norm(alpha)**2 - np.linalg.norm(Gamma)**2, atol=1e-7), "Wrong statistical noise sigma_sq_stat!"
            assert np.isclose(conf_strength, np.linalg.norm(Gamma)**2 / (np.linalg.norm(beta)**2 + np.linalg.norm(Gamma)**2), atol=1e-7), "Wrong confounding strength!"
            if eta is not None:
                assert np.isclose(eta, np.inner(beta, Gamma), atol=1e-7), "Wrong inner product eta!"
            else:
                theta_test = np.arccos(np.clip(np.inner(beta, Gamma) / (np.linalg.norm(beta) * np.linalg.norm(Gamma)), -1, 1))
                assert np.isclose(theta, theta_test, atol=1e-7), f"Wrong angle theta!{theta}!={theta_test}"
    return M, alpha, beta, sigma_sq